{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convolutional Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we'll see how to use convolutional networks for image classification.  We'll start with a simple dense network and gradually improve it until we're getting pretty good results classifying images in the CIFAR 10 data set.  We'll then see how we can avoid building a network from scratch by taking a large, pre-trained net and fine-tuning it to a custom domain.  Much of this content is originally based on Jeremy Howard's [fast.ai lessons](http://course.fast.ai/).  I've combined content from a few different lessons and converted code to use Keras instead of PyTorch.\n",
    "\n",
    "Since Keras comes with a pre-built data loader for CIFAR 10, we can just use that to get started instead of worrying about locating and importing the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from keras.datasets import cifar10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
    "x_train.shape, y_train.shape, x_test.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot a few of the images to get an idea of what they look like and confirm that the data loaded correctly.  You'll quickly notice the CIFAR 10 images are very low resolution (32 x 32 images with 3 color channels).  This makes training from scratch quite feasible even on modest compute resources."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_image(index):\n",
    "    image = x_train[index, :, :, :]\n",
    "    plt.imshow(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAHipJREFUeJztnWmMnNd1pt9Ta69cms2luZiLzFEiL5KVhqJYtiJZTqAYHsgajA37h6EfRhgEMRAjmR+CMhh7gvxwMmM7RjLjgI6UKIHjJbEFMzOeGWuEDITEHlnURlGibFES9yabSzd7r+0786OKA6p139vFXqop3/cBCFbfU/e7t+5Xp76q+37nHHN3CCHSI7faExBCrA5yfiESRc4vRKLI+YVIFDm/EIki5xciUeT8QiSKnF+IRJHzC5EohaV0NrN7AXwVQB7AX7r7F2PPX7N2nW/cNESs/E5Ds/BnVC5ntI9HPtdi9zQa+DGNdOQ9FhjNYvNf1BFh9I7NyFiRA0bv/4y/8GsfbAVY7tHi01/caKxXfKiw9eLoGUxOjLV1Zhbt/GaWB/BfAPwagFMAnjazA+7+MuuzcdMQvvinjwRtWZbRsbrL5WB7qauL9sny4T4AUHf+wVBAntryjXB7kU89+m7xAp9HjX3SIP6myDWI1Yu0T73Gj9jIkRcNLMr5Y7eTR281j4yVZZH5k47RD9fIPGLv00Yjslax8Uh7PbpW4Xn84e99ou1xl/K1/zYAR939dXevAvgWgPuWcDwhRAdZivNvA3Dyqr9PtdqEEG8DluL8oe9Tb/meYmb7zOygmR2cuDy2hOGEEMvJUpz/FIAdV/29HcCZ+U9y9/3uPuzuw2vWrl/CcEKI5WQpzv80gL1mttvMSgA+CeDA8kxLCLHSLHq3393rZvZZAP8LTanvEXd/aaF+Gdm1LZT5bnQ1C++iTl+epH2KvXx7OF/spjY475eRneN6ZGe+MVejtrnLs9RW6uJqRQN8x3lqdirYnjN+vL7etdTmkbGyyO62ERlzsbvskSWO7vazcxYTFmI7+rE5xnb72XoAQEZWJVuk6tAuS9L53f0HAH6w5FkIITqO7vATIlHk/EIkipxfiESR8wuRKHJ+IRJlSbv910oja2BiOixF1WpcErtw/mKw/dTpUdon39VLbX39/Gajco5LYkwFrNb53LNandpmJsNrAQDdRT4P5LjMM1kNy5/VKpea9uzeS23vvGEntXXHAquIFBWVqCLBOx4xZjEdkMU5LTbAaJHEpL4ceW1ZRGZdDnTlFyJR5PxCJIqcX4hEkfMLkShyfiESpaO7/VPT0/jR//0xsfGd7xzCQT+zFb4rO9cIKwQAUCxxWz7jn4cNsmE753xHvxHZie4t8d3ybuOnpqvMU401ctVg+/Q0VyQOHnqO2kYvvCVK+/+zZ/duahscHAy2d/f00D4eS8cVCZrJSEorADB2PjudSzAWLMSCoBYR2HMtSoWu/EIkipxfiESR8wuRKHJ+IRJFzi9Eosj5hUiUzgb2NDKMT4Xz1nkkd56R6IxCief964lIZfkct5VQorY5hOWmeuQzdHJmmtpmp7mtbFzO63Me9JMnL61Y5nkL56bmqO21k6ep7fjIWWpbtyacF3DH9u20z8bBDfx463kwViEXqbJEZMDFBu+wgkgAzxe40His+k48h9/SpUpd+YVIFDm/EIki5xciUeT8QiSKnF+IRJHzC5EoS5L6zOwYgEkADQB1dx+OPT9zx2w1LGsUi7GpkKinBo9Uc3Cb5SNllSIKSrUWlsRqkan39/RR2+TEDLVNVHkpr0okQqxUCkuV/SX+wvJ5Lm9O1yu8XyQCsnLhcrB9fJxHb/b2cTlyaGgrtd2wew+19ZXCsmiZrBMQzydZi6TVc3DJMRZ5yGTAmBrJJMdYrsP5LIfOf7e7X1iG4wghOoi+9guRKEt1fgfwQzN7xsz2LceEhBCd
Yalf++9w9zNmtgnA42b2irs/efUTWh8K+wCgq3fNEocTQiwXS7ryu/uZ1v+jAB4DcFvgOfvdfdjdh0tdfENHCNFZFu38ZtZrZv1XHgP4dQCHl2tiQoiVZSlf+zcDeKxVhqgA4O/c/X/GOmTumK2E5bJKjX8OsVJHXZFyUbGYp0gAYbT0E7NNR5KPdnXzwcrFSCLOGu83V+EyYN1IFFvkdZUiUXHxywM/ZqEQPmZsHpMzfB0vv3qE2i5c5GJTf1c4unD7Nh5duD4SQViKREfG6o1ldZ7ktU5UwFi0aMPDcnVHpD53fx3AzYvtL4RYXST1CZEocn4hEkXOL0SiyPmFSBQ5vxCJ0tEEnu6OKolusgaPemJ1ybJc+7LGmyhHEi3m+edhlgvLNYXIKtYi0XmlApcq+7p51NlMlSfcrCM8x0hZQ1Tq3FiOJDvNR6LYnFxXallE8iIJUgEgl+Pn5eylUWo7UwnXZTx6/ATts3FjuM4gAGzduoPa+vr6qa2rHJGlidRa84jUR2oXNq4hsaeu/EIkipxfiESR8wuRKHJ+IRJFzi9EonR2tx9APZLLjNEgO8RzU5O0TyGyBd+IiASFXJXaWEBQscgPWIgtcSQXXyyZYF+kTFmdfJxH0u2hFplHvcHXI2f8oE6iVRqRHf1GPpa0jptiue7MwmtVjyTjmzgzRm3HR45RW7nEd/R7enqojQWoxfIMFovh11Wt8LyQ89GVX4hEkfMLkShyfiESRc4vRKLI+YVIFDm/EInS8cCeSi0sHbE8fQCQkWAFVuYIAOqRPHezETmkGJHR8kTaKhd4Hyc59QDAPFLeKSK/ecZ1LxbXMdPgATVV8LFykfx+1cg5KxJd1HN8rFqOv66YnJfLR3IQWjgIKhInFM3/mEU00+osz0E4MR3RKpmcWuHHY/4yOzPBx5mHrvxCJIqcX4hEkfMLkShyfiESRc4vRKLI+YVIlAWlPjN7BMBHAYy6+7tbbQMAvg1gF4BjAD7h7jwUqkWWZZiZC0svhZj2kpFpRuSw2elz1FYqcTFnYDMv49RN1JpcREbLR3Lxea5GbZfHwrnnAGB2iss5O3ffGGyfrPXSPmNjl6mtXObRaDUi2wKAkTC8LKbZ8WWM9mtEDllCeI1z+UguwUiptEYsPDIW5ViZprZs/GSw/eLp1/lYJL9fLSI3zqedK/9fA7h3XtuDAJ5w970Anmj9LYR4G7Gg87v7kwAuzWu+D8CjrcePAvjYMs9LCLHCLPY3/2Z3HwGA1v+blm9KQohOsOK395rZPgD7AKDQxX93CiE6y2Kv/OfMbAgAWv/Tqgnuvt/dh919OF8qL3I4IcRys1jnPwDggdbjBwB8f3mmI4ToFO1Ifd8EcBeAQTM7BeDzAL4I4Dtm9hkAJwB8vJ3BHI5GnUgsEblmfbk72L6ml8tQsz2Rl2ZcoipO8WjALpIdc9MmvuUx182TOlbrXOrr7uKvLd8TXg8A6FmzJti+rneI9tkyWKG2WHThXER+myH9zp7nEmxtepzais7XqlDn5cvyWfhc12qR5K95vvYZ+PnMIqXNMMvHmzhzLNheGeNrNTUVPmd1kjg1xILO7+6fIqZ72h5FCHHdoTv8hEgUOb8QiSLnFyJR5PxCJIqcX4hE6WgCT7gD9bD0srann3ZbR2S70yMnaJ/ZyA1FlUgUnp09Tm27N4QlvU07ttE+r5w5Q22e8eixnmkuOa7t5XLTiydfCLb3beFRZX1lnoD0jZ+9TG2N3vXUtm7ve8NjbX0n7TN9/Ai15SORjGucR7LNTIXlw5lJel8aSsU+apuY48lCu9dtpLYN3fxcT5HIQ0RqShqLgo0kjJ2PrvxCJIqcX4hEkfMLkShyfiESRc4vRKLI+YVIlI5LfblGWNbY0sfllXNjYVmm1s+1kEI/lw5zxuWaeo3nId1567uC7WORWnfV9ZHoPOPLn1vD5bzxCR4hNjkXlgizGR4xV5nj0ufayDxOTnGJbfp8OAHpznXraJ+t
N4blQQAYf5lH7k2f5vLs2LmwbWKaJ0htkOhNALg8y99z3eu51Ne/g9vqpL7e3CyPtmQ1FC2mD84/RtvPFEL8XCHnFyJR5PxCJIqcX4hEkfMLkSgd3e0v5PMYWBPehR/s47vz45fCucwGunhASrnIdz3rNb67vemGcLkrANgztCPY/tIJXlZpXZmX66pHyl1t2sJ3xXODXBmZLoQ/z3P9fB5j589S285NvHzZTInPf6wRDiS6NHae9skNvYPatt90O7WdPvUKtc3NzgTbi3n+/vBI/a98xnMJVsZ5sNB5cIWmPhOeYy7Pr80NUjruWtCVX4hEkfMLkShyfiESRc4vRKLI+YVIFDm/EInSTrmuRwB8FMCou7+71fYFAL8J4Ipu85C7/2ChY5WKeezcMhC0/Zvf+BDtd/z1XcH2yTkeWFKZ4zJUvcKlvl1budzkWVgC8sEttM/liJw3PcPnv32QlwCrOw8kmpoOB8B4F89p2Oc8F18+45rS5rW8bNj0aFjSmzodlrUAoFbhr6t3M5cct77rg9SW1S4H20fPvEb7zExxWQ6R9VjTywPGCuA5GZ14YW2Gj+UkgMcjJdTm086V/68B3Bto/4q739L6t6DjCyGuLxZ0fnd/EsClDsxFCNFBlvKb/7NmdsjMHjEz/r1RCHFdsljn/xqAGwDcAmAEwJfYE81sn5kdNLODFZJoQgjReRbl/O5+zt0b7p4B+DqA2yLP3e/uw+4+XO7iG0RCiM6yKOc3s6Gr/rwfwOHlmY4QolO0I/V9E8BdAAbN7BSAzwO4y8xuAeAAjgH4rXYGy5tjTT4sRf3KrVxiu+1d4XJYkzM8x1nN+edarc7lkPoM/2kyOxceb3eVl+uaqXC5ZipSkqtY5KdmbIKXruraHY7em63wtfJ1g9R2+uwItb36Bi+XdtP6sFR54nxk7zjjUlmji0d99u28ldo+eMOuYPulk1zq++mzz1Db6NmfUluv8fyPqPByaXMNko8v49JnoRjuUyU5MoPHWOgJ7v6pQPPDbY8ghLgu0R1+QiSKnF+IRJHzC5Eocn4hEkXOL0SidDSBZ1avY+pSWA459Qa/VWD7tt3B9m1Dm2mfQg+XhrJImayJCxeobXw8PPcNAxton+lZLr3MzEYi/qa4NDQ5tZbabrxhT/h40xGpaZZLjhu7eTRgscJf2y/98vuD7ZdmeJ9jZ8MReABQzfGyYY1ZXsoLpITW1veG31MAsPG9v0Zt9bFwMlkAuHTkKWp74/DT1HbhtZ8F23Mlfs5yhbAMaJHktG85RtvPFEL8XCHnFyJR5PxCJIqcX4hEkfMLkShyfiESpaNSXz6Xx7ru3qBt8iKvFzdCopsGt/B6a2vz/KX19vM6eFjLJcK8hWWq/kiagrWRGoSeW1wdvyMv89p0GzeGpa2eHh41ORORFW/exSMWf3WYR9PNksjJmYgStXcHj4A8d5HLkWfO8kjBs2+cDLafiNTjm4vIxN3reCLRde8OpbpscsuNv0Jt2944FGw/9COeGvP82TeC7W48Qep8dOUXIlHk/EIkipxfiESR8wuRKHJ+IRKlo7v9xXweQwPhoBSr8oCPS+dGg+0vHDpK+zx3mOda27xtB7V98FfvpLZtG8NznxvjO6z5QkQKiOz2Fwr81LxjKy+T0N1VDLaXS/xzfk2ph9rQz+dYa/B5TJKAptkGV2iOvHqM2sYq4fJfAHDrnrDCAQBTm8Lr+MYIV5eOHOdqyguv8/fcZJmrSINr+BrftDmsqAzfyQOMnvvx48H240e5cjMfXfmFSBQ5vxCJIucXIlHk/EIkipxfiESR8wuRKObOAxwAwMx2APgbAFsAZAD2u/tXzWwAwLcB7EKzZNcn3D1SrwhY39/ndw2/J2h7zzvC5Z0AYO2GsJTzzEtcknklIhvdcfc91FYHX49/fc8Hgu3ru3ifrm4eJFIocvlndo7Lhxs38LXqKYcDp6qRcl0xLB8pexa5dlgxnHPv1eOnaJ8/+U9fobYL
ozx455dvD58XAPjoxz8dbPcKz/t3+OmfUNuZOpcqXxrn5bWyPM+F6LPjwfa9EZ84/eqzwfYfPXEAly9d4JO8inau/HUAv+/uvwjgdgC/Y2Y3AXgQwBPuvhfAE62/hRBvExZ0fncfcfdnW48nARwBsA3AfQAebT3tUQAfW6lJCiGWn2v6zW9muwC8D8BTADa7+wjQ/IAAwL+jCCGuO9p2fjPrA/BdAJ9zd14j+q399pnZQTM7WKm1Xz5YCLGytOX8ZlZE0/G/4e7fazWfM7Ohln0IQPAGfHff7+7D7j5cLobvOxdCdJ4Fnd/MDMDDAI64+5evMh0A8EDr8QMAvr/80xNCrBTtRPXdAeDTAF40s+dbbQ8B+CKA75jZZwCcAPDxhQ5Ua2Q4Px6WsF4p8qit/OjFYPuJkRHa58577qK2h/79H1Dbn/35f6W2//6PB4Ltv7CNl+sqlvLU1tu/htoaDZ7PbmDtALVtHAiXMItFCZZKPHIvFyltNtXgCfmqhfB15Wt/8Ve0z8uvvEht5SKf42MH/p7att9IpOW9/4r26S7z0mBrnL/mrX3UhDpZDwCYJpGOXuXy7M5t4ZyMByPrNJ8Fnd/d/xkA0w25YC6EuK7RHX5CJIqcX4hEkfMLkShyfiESRc4vRKJ0NIFnqVzGtl3vDNoamKT9arVwBFapl2srQzt4mSk3HoW3Yysvx/S/v//dYPvkWZ7IsqebR3OVuyPJPanAApQL/Gapvp7wmvR08wjCUkQe6irxOXoXf23nZ8Pn86UjL9M+H/4wF49uvuVmavv6X3L58MdP/o9g+54tPNlmqYfLsxfO8sSfL7z6M2or9vJ13LwmPJfGLJd7u0lC1rbC+Vroyi9Eosj5hUgUOb8QiSLnFyJR5PxCJIqcX4hE6ajU53DUEZYvGhmX30rlsEzVy4PiMDHFE2CeG+URhBcu8Rykp86Gowu9zpOUdJW5xFOrcSknlla1XOSnrbcclgHzBS5fdXfxKLauLi4RZnkuLJ04fy5scN7nY/ffT23vf//7qe3kSZ4U9LED/xhsf+6FnbRPY65KbWPnLlNb9eJpais0eCLXmfpUsP31sZO0T085LM9WKrO0z3x05RciUeT8QiSKnF+IRJHzC5Eocn4hEqWju/31egMXxsM75rU6L59UyIU/o7zOd8ufO3SY2t5z8y9F+vE8cqw8VbXAd/SrNb7LPjJygdrmIuWkSpF8fEUyXCzgo1jigULFiLLQcF6eamouvOs8MBjOMQgAgxt4LsTJCZ4tfsvQFmq7NBZWdn74wx/QPnNT09R28WJ4Zx4Apo1fSwuRAK88UUDWbw6XqQOATZvDr7keyf04H135hUgUOb8QiSLnFyJR5PxCJIqcX4hEkfMLkSgLSn1mtgPA3wDYAiADsN/dv2pmXwDwmwCuaCkPuTvXT9DMndewsDxkeZ5HbmomHKQzO8Vll7Pnw5IiAPzpn/05tR0/epzPoxqWUY6e5oFCHglYipXkqjW4jGYNXsYpTz7PLSL2WSRXnBsvTxXNF+fh193dy+d+8SI/Z+VISbGJy1wGrFTC8z92jAcDWURCrvHTAo8EQcUCtVgOxd4yz1E5Mx2eYxZ5v82nHZ2/DuD33f1ZM+sH8IyZPd6yfcXd/3PbowkhrhvaqdU3AmCk9XjSzI4A4KlxhRBvC67pN7+Z7QLwPgBPtZo+a2aHzOwRM+P5q4UQ1x1tO7+Z9QH4LoDPufsEgK8BuAHALWh+M/gS6bfPzA6a2cF6lSe9EEJ0lrac38yKaDr+N9z9ewDg7ufcveHuGYCvA7gt1Nfd97v7sLsPFyL3kAshOsuCzm9mBuBhAEfc/ctXtQ9d9bT7AfBIGiHEdUc7u/13APg0gBfN7PlW20MAPmVmt6CpYhwD8FsLDlYoYGDDALHy6LdZEmVViZTrykUirMbHxqltw8ZN1LZ2IBxlVY/IK5nzfHD1Gpe9GnUuscVy/2W18FxismKlwueYEckOABCJ6suR68p4
JDrvX370L9R29913U9tLLx+hNvayq5Fzlo+8F7PI+yomzzYqkZ+81fBcTh7nOfzy5XBOwNo1/LRuZ7f/nxGWdKOavhDi+kZ3+AmRKHJ+IRJFzi9Eosj5hUgUOb8QiWIek3KWmbUDa/0D93wgaMsi0VKkwhfyEbGiEElyabGXHInoYhFTuTyXhupVXjYsa3CJrRGRjbLIYrHTWa9x6XBqmkdHVipcjqzVIvMn6xg7Xk83T4S6a/duajv4zLPUNj4RToQai3KM+UQjYotUIgMsGgMZJJfj76uunnAE4dzUOBqNeluD6covRKLI+YVIFDm/EIki5xciUeT8QiSKnF+IROlorT6DwSwsXxSL/HPI8kS5aHBFo1iM5A6IBapFJJkyk/QifUqRFTZ0UVtMmmvEdFEiRcXkyA2DLNISqEXm4ZGoPiZVZhmXUqenuSx69tw5atu1i8uAk9PhKLeZ2XAtwSb8DVKPyoARCTZyzti5yZEalU1b+D03OjdJ+7zlGG0/Uwjxc4WcX4hEkfMLkShyfiESRc4vRKLI+YVIlI5KfQ6De1jW8CxSS45EYMUCpWKRb1EZsMAlMSMD5mITiRwvH5FyipEEk7UaT9JIE3VGphirJ5g3vlb1BpcBmbJYjLzm7v511LbtHbxWX6w+3SyprxiTMGPvHcvz+ceiAWPHzJPFiiddDUdHXr50gfaZj678QiSKnF+IRJHzC5Eocn4hEkXOL0SiLLjbb2ZdAJ4EUG49/x/c/fNmthvAtwAMAHgWwKfdI7Wp0NxVrs6FdzDZTjoAsA3W2M5xdHc1lt8vsjvvJOAjiwSCWKS8Uy6yk17s5jbP893+cmQ3mrO4fHb1WEmxavitkEWCX2LHm6nGgoj4rvhcPbxWsfcbWCAZAI+MFQveKZW4WhHLN8noITn8YsFAb3luG8+pAPiQu9+MZjnue83sdgB/DOAr7r4XwBiAz7Q9qhBi1VnQ+b3JlfSuxdY/B/AhAP/Qan8UwMdWZIZCiBWhre8IZpZvVegdBfA4gNcAjLv7le9ppwBsW5kpCiFWgrac390b7n4LgO0AbgPwi6Gnhfqa2T4zO2hmB9nvQCFE57mm3SF3HwfwfwDcDmCdmV3ZqdgO4Azps9/dh919uBjZ9BBCdJYFnd/MNprZutbjbgAfBnAEwD8B+Letpz0A4PsrNUkhxPLTjsYwBOBRaybfywH4jrv/NzN7GcC3zOyPADwH4OF2BnRa04jLK6z0E4zLLuVymdrigTHcViyF5beYrFgAl+wakeCSeizPYCyAhMiOLOcbEJe9LBZ8VI4ELRXD3/JiY8Uku9ga14icBwC5LLzGWWSsesSWj9TkyiJSZeycLaZkHpf02i8LtqDzu/shAO8LtL+O5u9/IcTbEN3hJ0SiyPmFSBQ5vxCJIucXIlHk/EIkii1GZlj0YGbnARxv/TkIoP2EYyuH5vFmNI8383abx05339jOATvq/G8a2Oyguw+vyuCah+aheehrvxCpIucXIlFW0/n3r+LYV6N5vBnN48383M5j1X7zCyFWF33tFyJRVsX5zexeM/upmR01swdXYw6teRwzsxfN7HkzO9jBcR8xs1EzO3xV24CZPW5mr7b+X79K8/iCmZ1urcnzZvaRDsxjh5n9k5kdMbOXzOx3W+0dXZPIPDq6JmbWZWY/MbMXWvP4j6323Wb2VGs9vm1mS0uQ4e4d/Qcgj2YasD0ASgBeAHBTp+fRmssxAIOrMO6dAG4FcPiqtj8B8GDr8YMA/niV5vEFAP+uw+sxBODW1uN+AD8DcFOn1yQyj46uCZpxuX2tx0UAT6GZQOc7AD7Zav8LAL+9lHFW48p/G4Cj7v66N1N9fwvAfaswj1XD3Z8EcGle831oJkIFOpQQlcyj47j7iLs/23o8iWaymG3o8JpE5tFRvMmKJ81dDeffBuDkVX+vZvJPB/BDM3vGzPat0hyusNndR4DmmxDAplWcy2fN7FDrZ8GK//y4GjPbhWb+iKew
imsybx5Ah9ekE0lzV8P5Q6lGVktyuMPdbwXwGwB+x8zuXKV5XE98DcANaNZoGAHwpU4NbGZ9AL4L4HPuPtGpcduYR8fXxJeQNLddVsP5TwHYcdXfNPnnSuPuZ1r/jwJ4DKubmeicmQ0BQOv/0dWYhLufa73xMgBfR4fWxMyKaDrcN9z9e63mjq9JaB6rtSatsa85aW67rIbzPw1gb2vnsgTgkwAOdHoSZtZrZv1XHgP4dQCH471WlANoJkIFVjEh6hVna3E/OrAm1kzs9zCAI+7+5atMHV0TNo9Or0nHkuZ2agdz3m7mR9DcSX0NwB+s0hz2oKk0vADgpU7OA8A30fz6WEPzm9BnAGwA8ASAV1v/D6zSPP4WwIsADqHpfEMdmMcH0PwKewjA861/H+n0mkTm0dE1AfBeNJPiHkLzg+Y/XPWe/QmAowD+HkB5KePoDj8hEkV3+AmRKHJ+IRJFzi9Eosj5hUgUOb8QiSLnFyJR5PxCJIqcX4hE+X8F78NzaabVgQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f5cbb8cc908>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_image(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJztnWuMnOd13/9n7rMzs1fuLne5lLmiKFuXSLJCqy7cpordOKoRQDbQBPYHQx+MMChioAbSD4IL1C7aD05R2/CHwi1dC1EKx5fGNiwUahpDuQhuAFm0I1MXShRFUeQu98K93+Y+px92GFCr5//uiEvOUn3/P4Dg7nPmed8zz7xn35nnP+ccc3cIIeJHYr8dEELsDwp+IWKKgl+ImKLgFyKmKPiFiCkKfiFiioJfiJii4Bcipij4hYgpqb1MNrNHAHwDQBLAf3f3r0Q9vtSf9eHxQtC2sV6n8xKWC44nE8ko3/jxEtyWSqa5LZEJ+5HkftQbNWqrNraoLZlucT8yTWozC89rtaLm8PUwi7hEIr4d6h4+XzIZXkMASCT4vcjA/W82uR+Nevi5tVr8NWu1ru+e2Gjya7jV4q9nqxl+bg7+vJrN8PE2V6qobJInvYPrDn4zSwL4LwB+C8AUgOfN7Cl3f4XNGR4v4D9852NB2//9qzl6rlLuA8HxQk8vnZOOuGiLBR7gB/rGqW2gZyI43t/XR+fMLFyktvNXfkVtvYc2qG3o0Ca1pbPhPyjlzRU6J5fjAZm0fmprNRvU1myuB8cHesNrCADZbA+1pRA+HgCsrlWpbXEufB1UNvhrtlUtUltUQC4vzfBjbnEf1zZWybn4+i4vha+P//3fTtM5O9nL2/6HAJxz9/PuXgPwPQCP7uF4QoguspfgPwTg0jW/T7XHhBDvAfYS/KHPFe94T2RmJ8zslJmdWlvmb32EEN1lL8E/BeDwNb9PALi880HuftLdj7v78d6B7B5OJ4S4kewl+J8HcMzMJs0sA+DTAJ66MW4JIW42173b7+4NM/s8gP+DbanvCXd/OXJSAkiSm3/hAN/dPv2LvwuOHz74IJ1TKuSprVLjMk95ne/mlvvDCkrDuGQ3MM6X+NhhbivnuPqx3uI796218M59thmWWAHAs/w515v8uaWSfFd8sPdAcLwnE3GuzRK1rW2OUdv64hq1XTz7VnA8meXSG9JcspuanqW2UpGrJhvrXKpsNNg8vlZUOXwXtXn2pPO7+9MAnt7LMYQQ+4O+4SdETFHwCxFTFPxCxBQFvxAxRcEvREzZ027/u6Veb2B6fjFoG58coPOSybAENFi8Peps1DL95nlqe3OaJ2ccGg/LXpvOJaqB1DK1NXpfpbZEMbxOAFCt88Sk9ZVwMshgiifNZCLkt94+LueV8jxJp1oPr3+twWU5NLj8tjo3TG3L5/llfPbUC8HxwmGeNHPojhFqy0Ukha2t8+dWrfDzwcLHXFi8QqfU6pXgeDMie3AnuvMLEVMU/ELEFAW/EDFFwS9ETFHwCxFTurrbX6k0cfZsuBzTkdv5bu7k+28Ljp9//Ryds7nFE4UKJb7zvV4Ol1QCgJdeezE4Xhw/RucMlXgNv0aC78xOnee7/XDu/0AmXIYsqiRULsPXfrBvlNo2Vnkiy6tnwucbKBykc0q9/F5UH+LJWJvT/Jizc+EyZJMT/Hg9Re5Ho8XXvlbh11wqw4+5vBSOia3N8I4+ABhz/10k9ujOL0RMUfALEVMU/ELEFAW/EDFFwS9ETFHwCxFTuir11WqOSxdZa6Iynbc2dCk4XktwWa6Z4ok9/QOD1Hbs/ZPUNjcfPt8mSbIAgNMvc8mukeB13foPcPkQzrvXpLNhXwYG+XMu9oTr7QHA+hrv/LQwx0uxt2rhSyvXG1Gnr8aTu16s8CSu6uAQtSVGwjX8enL8dVleWaK2mct87RtVLqfWq/wa2dgMJwQ1GlHyLCmG
GdF6bSe68wsRUxT8QsQUBb8QMUXBL0RMUfALEVMU/ELElD1JfWZ2AcA6gCaAhrsfj3q8u6FRDdcrW5nn2W/1rXAdvGyBpzANHOTSlme5hDJyB69Zt9YKZ21tlLnveXA/Fhe5/FPK9FHb+EQ4Uw0A6pgPjq+2+Lk2lxaoLZfkfmxwdRal3rAU1cjwmobzm7x23tM/5mvc8nf0h/0HjmbCx0w6z+pbuMxr8dUq/JpLprjMViE1DQHAiTxXLPG1Nw/PsXdxP78ROv9vuju/eoQQtyR62y9ETNlr8DuAvzSzX5jZiRvhkBCiO+z1bf9H3P2ymY0A+KmZveruz177gPYfhRMAkCvxyi9CiO6ypzu/+/ZOi7vPA/gxgIcCjznp7sfd/Xi6p6upBEKICK47+M2sYGalqz8D+DiAl26UY0KIm8tebsWjAH5s2zJFCsCfuftfRE1IwJAlrYnqZS5FDRwMF2icnpujc9Yq09TmibPUdv+9d1LbP/7tsB+FDM9Uq29x29mzEZmMy7xVUz5PMroANDPhTMGptYt0zlCJy1DjA/yjWmkwT20Zcl/ZbHCp7I2pcAYeAJz/Gc/grK2/QW12ODxva57LeWPv40U68/0RH10T/BpOJPm8np5wTNQiJOR0IuyjWRekPnc/D+D+650vhNhfJPUJEVMU/ELEFAW/EDFFwS9ETFHwCxFTuvqtm2azhfXlcGZc7wEuAS2uzQTHc0WeRbWxGVFMscELZ776ypvUNjMdlstKpRydMzp6mNpGjnD5Z+utTWq7dIVLW/lSuP/f0HAvnTPQGyFRJaaoLZXhzzuTCGekNWq8WGirHlF8ssWzAe/6NS4DfmAybCv18OKjA8O8h+LWVoHaajX+eq4vclm6WQufL5/hkiOaJF7Uq08IsRsKfiFiioJfiJii4Bcipij4hYgp3c2xdcBa4R3dRET9s43ySnB8dJTXfEuC1z+7fJknsqw538FeWw4nWqRyPAlncZPb+kq8PVWuyJNmeocmqC2fDb+kowNjEXN4PTuAr1W9zlWTej3cDsvT/H6ztjxMbb1crMDDv8XbdWVJTcOxg7xWYyZiPc6+yJWApeUtaqus8SQuJ+pT3wHuY5MpVtrtF0LshoJfiJii4Bcipij4hYgpCn4hYoqCX4iY0lWpr9VqYWN9PWhLbvK/Q6V02M36FpdWEuC2fJYndSSMS32lgXCbrGaSJxGVa1zq25rjNdomD91DbX15LomhHtZ66qtcNhooRCSQpLmPWxWefIRUeE1aSX7JnT8XrmUHAAOjvG7hg7/Opb48jgXH681wghkAVDa57Nyo8wSdWjl8bQNANsn9zxfCtmSEAmuJsORo1rnWpzu/EDFFwS9ETFHwCxFTFPxCxBQFvxAxRcEvREzZVeozsycA/A6AeXe/tz02COD7AI4AuADg99ydF1n7h2MByWz47025wrPHNt4KSyjVBZ4pNTLOJY9CRLurVZJBCAClVFgiHBzlmsyVK/xcyWZE1laVH7OywWXMrIVrzCWSYZkSAJYW+PFSBZ65t7jOJdPyBpHSUtyPS9P8chyb4HX6ckXeeitVCUuV5TKXN73KfZw4xKXPvgjJdDaiJmOhGJ7nCX4u0vUOqYisyZ108sg/AfDIjrHHATzj7scAPNP+XQjxHmLX4Hf3ZwEs7Rh+FMCT7Z+fBPDJG+yXEOImc72f+UfdfQYA2v/zqhpCiFuSm/71XjM7AeAEAGQK3S0cJITgXO+df87MxgCg/X+4VhIAdz/p7sfd/Xg6slyUEKKbXG/wPwXgsfbPjwH4yY1xRwjRLTqR+r4L4GEAB8xsCsCXAHwFwA/M7HMALgL43c5O5zAPZ3t5hUtKw73hFk/JMs+ma6zzDLEWKXIJALUKz8xaWAjLNZ7mWWCFNG/vNDwyTm0jQ7yt1XB/xBZLPfzuKp3kraTqSZ7hthZRgHRqjrc2m50KZ78t8aQ4NKr3UVupn/sxu/AKtfVZWEbrydxN54yM30lt44dK
1GYNnhG6fhcvyFprhNe/aVyC3aqGZe5c/jk6Zye7Br+7f4aYPtbxWYQQtxz6hp8QMUXBL0RMUfALEVMU/ELEFAW/EDGly736HKhXgqZMiktzxUw4My7d5O43alw6tGzYBwDoyfEsvMX5cOZhkx8Od91+mNoODU1SWyrFpbnKJl+rNMKSkiUjeiHWeAbka29epLaZFW5LkD5+rRXu+6DzLM07B/h9qrHFX4BaKiy/JesLdI4l+LkyeX6u0QPhYqEAcKD3Nmpb2wwnxFbrPGuykAoXLc1nvk/n7ER3fiFiioJfiJii4Bcipij4hYgpCn4hYoqCX4iY0lWpL5lMoLcvnGWVK/CsJ0+FZapCPy+A2WhymaTR4MUUN1Z5JlVyIyyJZVPcd5S5tIUyz9yzFO/H12zw551Nh231Ji+QuhpRetXX7qK2fH2Q2zz8vLPJQ3TO7MopajuS4pmME7l7qa2eCD/v8hbPZFytzVBba4kXErUWLyTaX+C2ViIsL6+vcbk6UxgIjnvnrfp05xcirij4hYgpCn4hYoqCX4iYouAXIqZ0PbEnWQ1vRzaN1+Ore3jHditiZ3Nrg+/opzN8Yi+p+QYA2US4Pl6m0UvnFJLvo7Zk9Si1tcqj1JZP83ZSaIb/nluT7xyPlbiPB/s/TG3lJq93uLkUTtJ5c/4tOmcg9TK19Tl/XW4b4et4ZvaN4HjCwrvlAJA2rozUqnwdK2VuKxd5bb1mJqwWrVUiagKuhBWJap2rGDvRnV+ImKLgFyKmKPiFiCkKfiFiioJfiJii4BcipnTSrusJAL8DYN7d722PfRnA7wO42kPpi+7+9K5nqwOt+bDM1sq36LRagtT9y/M6d5l0uMYZACRq/FzeqFFbqxFerpHxB+icdPP91HblMk8ISqci6hPmuSzarIUTmspl/rxyeS4pJSKukL7+MWrL9IZl0aVhvvaZApfz1io8+2iu/BK1FQ+G72+5Jpf6qhWeOJVs8hZrDl4ncXbp76ktmw63ABsc5O3LEvWwj6lU581wO7nz/wmARwLjX3f3B9r/dg98IcQtxa7B7+7PAljqgi9CiC6yl8/8nzez02b2hFnE16WEELck1xv83wRwFMADAGYAfJU90MxOmNkpMztVi6ilL4ToLtcV/O4+5+5Nd28B+BaAhyIee9Ldj7v78Uym880IIcTN5bqC38yu3eb9FAC+3SqEuCXpROr7LoCHARwwsykAXwLwsJk9AMABXADwB52cLJcp4O6JXw/amj28TVYzHa4HN9bPa+Dl+nimnbW4JHPlCm9BtbQZltiSuTvonEqFZ+CVSesyAMjlea24Wo3PK2+GaxBubvIsx2ZExl+zyWXF3lJYogKAfDEsY05f4XvHlSSX+mY2r1BbcZFnaSYHwn7U1y7QOT0JLiEP5I9QWyrDr6tGlR+zkA3L0hMHefuvNMK1ELMZLtvuZNfgd/fPBIa/3fEZhBC3JPqGnxAxRcEvRExR8AsRUxT8QsQUBb8QMaWrBTx78kXcd//DQVuij8tGiWIhON6f49JQMsulwyR4C62XX+MtoxYvzgXH35zlLb7SKS7L5Yv8S0+ZOi+O6XUuG22uhgtnNpy3L8tk+HpsbXA/zl8IF8cEgGIu7GOzxS+5jTrPPLyyvkhtR+tHqG1pOlyM8+KFM3ROusZfl/5i+BoAgPEjfdS22uASZ6s/fB0PpiPkzWw4Xra/d9cZuvMLEVMU/ELEFAW/EDFFwS9ETFHwCxFTFPxCxJSuSn3ZngLuuO9DQZuneTZSMxWWa1JJnqmWbPLjWZ5LOVsv8Qy36UthuWmpwmWoUpEXg2zM8p5wPVk+b2RwhNqGesNy08YWX6uoLMF6hctvGytr1FZphbMBE62I41UucRs5HgCstbgcaYlwxl/aeC/EV85xCbPvAD/XcorL1ekCf603iKy7uMz77k2OHg+OVxv8dd6J7vxCxBQFvxAxRcEvRExR8AsRUxT8QsSUru72J5JJ9PSFd6MbLf53
qMlKo6X5DnDLebJNLiKhph5RK27u9VeC404SjwBg+OA91HbutcvUVjbeyss2eZJO6lB4d9vA69zNXLxAbZtbfEd/a4vvRidJXUDziN3o3Ao1OanjCACXZrlKMNAXfm0O3zZB51SrfO3LNf6ca1VuKw1y/yvVcDJObY3XccwirEjUG/za2Inu/ELEFAW/EDFFwS9ETFHwCxFTFPxCxBQFvxAxpZN2XYcB/CmAgwBaAE66+zfMbBDA9wEcwXbLrt9z9+XdjpcgKptHtIWqk9pujSZPSGlluOTRWudJFrbBk3QaG+H6bQPDk3RO9Qqv+bY5zyWqRkRLsfoGl98WyfmSWS5vlss8WaVc5uda3+JrlUyQSyvJX7OJSX45jozx9msRnd7gHpY4N+uzdM7kkduoLdUMt8kCgK3ay9SWSE1RW60ZlhYLRS5HtsglTJ5u2KcOHtMA8EfufheADwP4QzO7G8DjAJ5x92MAnmn/LoR4j7Br8Lv7jLv/sv3zOoAzAA4BeBTAk+2HPQngkzfLSSHEjeddfeY3syMAPgjgOQCj7j4DbP+BAMCTzIUQtxwdB7+ZFQH8EMAX3J1/EHznvBNmdsrMTq0s77olIIToEh0Fv5mlsR3433H3H7WH58xsrG0fAzAfmuvuJ939uLsf7x8YuBE+CyFuALsGv5kZgG8DOOPuX7vG9BSAx9o/PwbgJzfePSHEzaKTrL6PAPgsgBfN7IX22BcBfAXAD8zscwAuAvjd3Q7k7iiTenG1Mq+dV6mFW1A1PTwOAI2I9kgN8DpyW6tc9kpkw/JbqsCXcWWBf0JamImQf5xLYo0mz1gs9o+F51S41Neq8eNtlXmWY6UZfLMHADDSAiyV5lrUgYmw7wBwx51cTp1d5HJqhiiEluBzapv82jk48GvUhsQ4NXmRXwevvRr+ODw2zOsMFrLhFl+pxM/pnHc8drcHuPvPADDR+WMdn0kIcUuhb/gJEVMU/ELEFAW/EDFFwS9ETFHwCxFTulrA0wE0SbZaKyIbKZcJt0GqVyNaUK3MUNtSnReK7Bnqp7Z/9vF/Ghy/vMW/uXhpaZraho/ydLSWRRQ0rXNproZwEclCL5eh5i/xtarUuNR37IFBakM+/IIurvJMwP4RXjgTxgtgljd4BuTgcLiAZyMiAfXAaLjILAAMD/PXJZE4QG0r5bA0BwDD/eFjZpN8zvzlsMzdqIeLgYbQnV+ImKLgFyKmKPiFiCkKfiFiioJfiJii4BcipnRX6ms5arWwFGERrhjr49fkc9I5LqPl+sPSIQAUN7lt/Xy44Obxe4bpnKP38Gw6JHjWVq3M/y4//ywv/LmwEJbE8iX+vLbKvMdcX0SPufs+9D5qe3P+tbChxGW58dsOUtvAAM/4Kxa4jFluhLP31rciCrw6f85TCy9R22A/l/qqW1w+7MuH61zUIzJdq5Ww/613UcFTd34hYoqCX4iYouAXIqYo+IWIKQp+IWJKd3f7HWjWwjuYzQqvWZdKhXcwLcVr+JV6eZJIs8wTe6YvnqG21186Fz5X7gN0TmWQt4UqkzZkADCU5y2jEi2+VsMDdwbHs/lwggsAVCOSQfoO8ESneoP7v76+EBw/NMGVEYtov/a3f/UctaV7uP8jt4Wvt0ySq0Gzl3kyU63JE5OWNrjqMJjjbb76iuFCg40Uvzc3WuHnnIyYsxPd+YWIKQp+IWKKgl+ImKLgFyKmKPiFiCkKfiFiyq5Sn5kdBvCnAA4CaAE46e7fMLMvA/h9AFd1kS+6+9PRx3Kk0/Wgrb7B69KlMuHkmEozLCcBwOW509T26qkXqa2ULFJboZ4Ljp/5mxeC4wCQPcITWRYj5M2eo1xiOzLBa7tNzYUTPpq1Bp2TymSobZRIZQDQcp4Q1NoKH7MnwSW2N197ndr+7jne2mzibn4Zt0rh+1u6MUTnNNb4egwO83NdePMNant1lbcA+/hvhmtDHpzgcvVmIyw5WqLzGn6d6PwNAH/k7r80sxKAX5jZ
T9u2r7v7f+74bEKIW4ZOevXNAJhp/7xuZmcA8G8sCCHeE7yrz/xmdgTABwFc/brV583stJk9YWbhpGQhxC1Jx8FvZkUAPwTwBXdfA/BNAEcBPIDtdwZfJfNOmNkpMzu1usK/ViuE6C4dBb+ZpbEd+N9x9x8BgLvPuXvT3VsAvgXgodBcdz/p7sfd/XhfP9/EEkJ0l12D38wMwLcBnHH3r10zfm1dpU8B4PWNhBC3HJ3s9n8EwGcBvGhmVzWtLwL4jJk9gO0uXBcA/MFuB2p6Dcv1cP25WpVn6G0SFXBuhUt2l5f/ltoWZvnHj4Ppe6htyMKS41pElmB6NpyxBQCZMpffpppnqe39H+W18xZbYV+WL/OXeniMy3n3fYjfH3KFsPQJAAsL4azEK1e45FUo8jqDd901QW29E1wm9mb4umrW+XrMTvM2cJtLfF6tyqXblY1Vapu+K1z7r1AaoXNmFsJSdr3B42gnnez2/wxASKyO1PSFELc2+oafEDFFwS9ETFHwCxFTFPxCxBQFvxAxpasFPButOpY3ZoK2zTVe6LJZDksvKxs8i6pV4ZJHXw9vabS1Gi7SCQCFwbDUlyAFGAEgneNZgr113sIpMcoz9waGucTW2xfOIrz4GpcjDbyl2NIcvz9UGzyrcvRgWJq7NM1lucUFLrF5mhcLHeHLgWw2vB7bX18JU63yzLiZs2vUVkhzR+58YJLaNogMuLDMr9N0NizPmqldlxBiFxT8QsQUBb8QMUXBL0RMUfALEVMU/ELElK5Kfa1mHeX1sKRnSd4fLV0KZ0v19UTINee5VFYaDhcRBYD6AZ51ZunB4Pj44L10ztQ0lzBXX+eZXncfupvaikUu5xyeCEtii5f58zr/Cj9eeY3LgMkeLttl8mGpdXQ8vIYAMDvFpcNqi8uAcO6/ISzb9fbzQqKTR3lRqivnwlmpANAgBV4BYG0pXFgVAGZnwvJhtcnl2SHSQ9ES/PXaie78QsQUBb8QMUXBL0RMUfALEVMU/ELEFAW/EDGlq1KfNyooL70atCWzXAqpWliuyZS4tDJ2zzi11eu8YGUjy/8etlbD2Xtr81zy2ljhtvIMzzx88XlewHOol79siXQ4i/DDD3Pp88jkKLUNDvPXpXeEy2X5ofBrk0gcpHMWpnnm2/wSz7ZsZS9SG+ppMon348v0cJvxp4xSkWcDtlrr1LaxES7k2kjwAq+5XLiPX6vZea8+3fmFiCkKfiFiioJfiJii4Bcipij4hYgpu+72m1kOwLMAsu3H/7m7f8nMJgF8D8AggF8C+Ky780JrANIJw8F8+JRbpNbatpPhnWNP8b9dmQG+k15b5m2htuapCctnFsPn2oio01cdorZGOqI+XsRStpp85355LpwEtV7nx7t9MtwuCgCqdb7jvHQpvB4AkNgIL2SuyJ/z5OT91DZ6KLy7DQDLFb4Ff+VKeJe9VeNKUTLDr8X7/9ERPq+5TG0tRKg+pMWWkeseACxBkpm46++gkzt/FcBH3f1+bLfjfsTMPgzgjwF83d2PAVgG8LnOTyuE2G92DX7fZqP9a7r9zwF8FMCft8efBPDJm+KhEOKm0NFnfjNLtjv0zgP4KYA3AKy4+9X3hFMADt0cF4UQN4OOgt/dm+7+AIAJAA8BuCv0sNBcMzthZqfM7NTaBv+2mBCiu7yr3X53XwHwNwA+DKDfzK7u3k0AuEzmnHT34+5+vLcY8d1IIURX2TX4zWzYzPrbP+cB/HMAZwD8NYB/2X7YYwB+crOcFELceDpJ7BkD8KSZJbH9x+IH7v6/zOwVAN8zs/8I4O8BfHvXk3kSBxrh+mjVMd7yan4qXMtsfmqOzmn08I8YqVpEm6xpnvSTWyKyVyLiHU2DP6/CHVyyGzrK69IlI/zHfHitZs/ztWoucxlqZDJirVq8Xly+OhYcX1rltfjSTZ6gMzTKk48ODvJ6h83KdHD80jRfj3wxqlUaf60bFS7NpdIRGtxC+LWurvJrsV4JX4ve
6rxd167B7+6nAXwwMH4e25//hRDvQfQNPyFiioJfiJii4Bcipij4hYgpCn4hYop5RKujG34ysysA3mr/egAA78/UPeTH25Efb+e95sf73H24kwN2NfjfdmKzU+5+fF9OLj/kh/zQ234h4oqCX4iYsp/Bf3Ifz30t8uPtyI+38/+tH/v2mV8Isb/obb8QMWVfgt/MHjGz18zsnJk9vh8+tP24YGYvmtkLZnaqi+d9wszmzeyla8YGzeynZvZ6+/9w+uPN9+PLZjbdXpMXzOwTXfDjsJn9tZmdMbOXzexft8e7uiYRfnR1TcwsZ2Y/N7Nftf349+3xSTN7rr0e3zcz3lesE9y9q/8AJLFdBux2ABkAvwJwd7f9aPtyAcCBfTjvbwB4EMBL14z9JwCPt39+HMAf75MfXwbwb7q8HmMAHmz/XAJwFsDd3V6TCD+6uibYrsFbbP+cBvActgvo/ADAp9vj/xXAv9rLefbjzv8QgHPuft63S31/D8Cj++DHvuHuzwJY2jH8KLYLoQJdKohK/Og67j7j7r9s/7yO7WIxh9DlNYnwo6v4Nje9aO5+BP8hAJeu+X0/i386gL80s1+Y2Yl98uEqo+4+A2xfhABG9tGXz5vZ6fbHgpv+8eNazOwItutHPId9XJMdfgBdXpNuFM3dj+APlTTZL8nhI+7+IIB/AeAPzew39smPW4lvAjiK7R4NMwC+2q0Tm1kRwA8BfMHd17p13g786Pqa+B6K5nbKfgT/FIDD1/xOi3/ebNz9cvv/eQA/xv5WJpozszEAaP8f0Tvo5uHuc+0LrwXgW+jSmphZGtsB9x13/1F7uOtrEvJjv9akfe53XTS3U/Yj+J8HcKy9c5kB8GkAT3XbCTMrmFnp6s8APg7gpehZN5WnsF0IFdjHgqhXg63Np9CFNTEzw3YNyDPu/rVrTF1dE+ZHt9eka0Vzu7WDuWM38xPY3kl9A8C/3Scfbse20vArAC930w8A38X228c6tt8JfQ7AEIBnALze/n9wn/z4HwBeBHAa28E31gU//gm238KeBvBC+98nur0mEX50dU0A3Iftorinsf2H5t9dc83+HMA5AP8TQHYv59E3/ISIKfqGnxAxRcEvRExR8AsRUxT8QsQUBb8QMUXeFl5xAAAAFElEQVTBL0RMUfALEVMU/ELElP8HMHi5yJpU4RwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f5cbb8cc5c0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_image(6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to do a data conversion to get the class labels in one-hot encoded format.  This will allow us to use a softmax activation and categorical cross-entropy loss in our network.  CIFAR 10 only has 10 distinct classes so this is fairly straightforward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], dtype=float32)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import keras\n",
    "\n",
    "y_train = keras.utils.to_categorical(y_train, 10)\n",
    "y_test = keras.utils.to_categorical(y_test, 10)\n",
    "\n",
    "y_train[0]"
   ]
  },
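  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sketch of why the one-hot format matters (the symbols $z$, $p$, $y$ and $c$ are notation introduced here, not names from the code): if the final layer produces scores $z_1, \\dots, z_{10}$, softmax turns them into class probabilities\n",
    "\n",
    "$$p_k = \\frac{e^{z_k}}{\\sum_{j=1}^{10} e^{z_j}}$$\n",
    "\n",
    "and categorical cross-entropy compares them against the one-hot label $y$:\n",
    "\n",
    "$$L = -\\sum_{k=1}^{10} y_k \\log p_k = -\\log p_c$$\n",
    "\n",
    "where $c$ is the true class.  Because $y$ is zero everywhere except at position $c$, only the probability assigned to the correct class contributes to the loss."
   ]
  },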
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The only other pre-processing step is to normalize the input data.  Since every pixel is an RGB value between 0 and 255, we can keep it simple and just divide by 255, which scales the inputs to the range [0, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = x_train.astype('float32')\n",
    "x_test = x_test.astype('float32')\n",
    "x_train /= 255\n",
    "x_test /= 255"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a few useful configuration items to use throughout the exercise.  The input shape variable will have a value of (32, 32, 3) corresponding to the shape of the array for each image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_shape = x_train.shape[1:]\n",
    "batch_size = 256\n",
    "n_classes = 10\n",
    "lr = 0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can get started with the actual modeling part.  For a first attempt, let's do the simplest and most naive model possible.  We'll just create a straightforward fully-connected model and stick a softmax activation on at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import Model\n",
    "from keras.layers import Activation, Dense, Flatten, Input\n",
    "from keras.optimizers import Adam\n",
    "\n",
    "def SimpleNet(in_shape, layers, n_classes, lr):\n",
    "    i = Input(shape=in_shape)\n",
    "    x = Flatten()(i)\n",
    "    \n",
    "    # One Dense + ReLU block per entry in `layers`; each entry is the layer width\n",
    "    for units in layers:\n",
    "        x = Dense(units)(x)\n",
    "        x = Activation('relu')(x)\n",
    "    \n",
    "    x = Dense(n_classes)(x)\n",
    "    x = Activation('softmax')(x)\n",
    "    \n",
    "    model = Model(inputs=i, outputs=x)\n",
    "    opt = Adam(lr=lr)\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the architecture is somewhat flexible in that we can define as many dense layers as we want by just passing in a list of numbers to the \"layers\" parameter (where the numbers correspond to the size of the layer).  In this case we're only going to use one layer, but this capability will be very useful later on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, 32, 32, 3)         0         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 3072)              0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 40)                122920    \n",
      "_________________________________________________________________\n",
      "activation_1 (Activation)    (None, 40)                0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 10)                410       \n",
      "_________________________________________________________________\n",
      "activation_2 (Activation)    (None, 10)                0         \n",
      "=================================================================\n",
      "Total params: 123,330\n",
      "Trainable params: 123,330\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = SimpleNet(in_shape, [40], n_classes, lr)\n",
    "model.summary()"
   ]
  },
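  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the summary can be checked by hand.  A dense layer with $n_{in}$ inputs and $n_{out}$ units has $n_{in} \\cdot n_{out}$ weights plus $n_{out}$ biases (the symbols $n_{in}$ and $n_{out}$ are just notation for this check):\n",
    "\n",
    "$$32 \\cdot 32 \\cdot 3 = 3072 \\quad \\text{(flattened input)}$$\n",
    "\n",
    "$$3072 \\cdot 40 + 40 = 122{,}920 \\quad \\text{(hidden layer)}$$\n",
    "\n",
    "$$40 \\cdot 10 + 10 = 410 \\quad \\text{(output layer)}$$\n",
    "\n",
    "$$122{,}920 + 410 = 123{,}330 \\quad \\text{(total)}$$\n",
    "\n",
    "Nearly all of the parameters sit in the first dense layer, because every one of the 3072 input values connects to every hidden unit."
   ]
  },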
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our last step before training is to define an image data generator. We could just train on the images as-is, but randomly applying transformations to the images will make the classifier more robust. Keras has a utility class built in for just this purpose, so we can use that to randomly shift or flip the direction of the images during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "datagen = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try training for 10 epochs and see what happens!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "196/196 [==============================] - 26s 134ms/step - loss: 2.4531 - acc: 0.0974 - val_loss: 2.3026 - val_acc: 0.1000\n",
      "Epoch 2/10\n",
      "196/196 [==============================] - 23s 120ms/step - loss: 2.2576 - acc: 0.1268 - val_loss: 2.1575 - val_acc: 0.1623\n",
      "Epoch 3/10\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 2.1256 - acc: 0.1741 - val_loss: 2.0836 - val_acc: 0.1764\n",
      "Epoch 4/10\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 2.1123 - acc: 0.1760 - val_loss: 2.0775 - val_acc: 0.1972\n",
      "Epoch 5/10\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 2.0938 - acc: 0.1802 - val_loss: 2.0716 - val_acc: 0.1710\n",
      "Epoch 6/10\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 2.0940 - acc: 0.1784 - val_loss: 2.0660 - val_acc: 0.1875\n",
      "Epoch 7/10\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 2.0894 - acc: 0.1822 - val_loss: 2.1032 - val_acc: 0.1765\n",
      "Epoch 8/10\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 2.0954 - acc: 0.1799 - val_loss: 2.0751 - val_acc: 0.1745\n",
      "Epoch 9/10\n",
      "196/196 [==============================] - 23s 120ms/step - loss: 2.0853 - acc: 0.1788 - val_loss: 2.0702 - val_acc: 0.1743\n",
      "Epoch 10/10\n",
      "196/196 [==============================] - 23s 120ms/step - loss: 2.0889 - acc: 0.1775 - val_loss: 2.0659 - val_acc: 0.1844\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f5cb3aa60f0>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),\n",
    "                    epochs=10, validation_data=(x_test, y_test), workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clearly the naive approach is not very effective.  After 10 epochs the model is at about 18% validation accuracy, which is only a bit better than the 10% we'd expect from randomly guessing among 10 classes.  Let's replace the dense layer with a few convolutional layers instead.  I'm not going to cover convolutional layers in depth here, since there are tons of great resources out there already to learn about them.  If you're new to the concept, I would recommend [this blog series](https://adeshpande3.github.io/A-Beginner%27s-Guide-To-Understanding-Convolutional-Neural-Networks/) as a starting point.  For our first attempt at using convolutions, we'll use a kernel size of 3 and a stride of 2 (rather than using pooling layers in between the conv layers), followed by a global max pooling layer to condense the output shape before going through the softmax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers import Conv2D, GlobalMaxPooling2D\n",
    "\n",
    "def ConvNet(in_shape, layers, n_classes, lr):\n",
    "    i = Input(shape=in_shape)\n",
    "    \n",
    "    # Chain the conv layers, starting from the input tensor\n",
    "    x = i\n",
    "    for filters in layers:\n",
    "        x = Conv2D(filters, kernel_size=3, strides=2)(x)\n",
    "        x = Activation('relu')(x)\n",
    "    \n",
    "    x = GlobalMaxPooling2D()(x)\n",
    "    x = Dense(n_classes)(x)\n",
    "    x = Activation('softmax')(x)\n",
    "    \n",
    "    model = Model(inputs=i, outputs=x)\n",
    "    opt = Adam(lr=lr)\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This time let's try using 3 conv layers with an increasing number of filters in each layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_2 (InputLayer)         (None, 32, 32, 3)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 15, 15, 20)        560       \n",
      "_________________________________________________________________\n",
      "activation_3 (Activation)    (None, 15, 15, 20)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 7, 7, 40)          7240      \n",
      "_________________________________________________________________\n",
      "activation_4 (Activation)    (None, 7, 7, 40)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 3, 3, 80)          28880     \n",
      "_________________________________________________________________\n",
      "activation_5 (Activation)    (None, 3, 3, 80)          0         \n",
      "_________________________________________________________________\n",
      "global_max_pooling2d_1 (Glob (None, 80)                0         \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 10)                810       \n",
      "_________________________________________________________________\n",
      "activation_6 (Activation)    (None, 10)                0         \n",
      "=================================================================\n",
      "Total params: 37,490\n",
      "Trainable params: 37,490\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = ConvNet(in_shape, [20, 40, 80], n_classes, lr)\n",
    "model.summary()"
   ]
  },
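  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check on the output shapes in the summary above, here is a worked version of the standard size formula for a convolution with no padding.  This assumes Keras's default `padding='valid'`, which `ConvNet` uses since it never sets `padding`; the symbols are just shorthand.  For an input of width $n$, kernel size $k$ and stride $s$, the output width is\n",
    "\n",
    "$$n_{out} = \\left\\lfloor \\frac{n - k}{s} \\right\\rfloor + 1$$\n",
    "\n",
    "With $k = 3$ and $s = 2$, the three conv layers give\n",
    "\n",
    "$$32 \\rightarrow \\left\\lfloor \\frac{29}{2} \\right\\rfloor + 1 = 15, \\qquad 15 \\rightarrow \\left\\lfloor \\frac{12}{2} \\right\\rfloor + 1 = 7, \\qquad 7 \\rightarrow \\left\\lfloor \\frac{4}{2} \\right\\rfloor + 1 = 3$$\n",
    "\n",
    "which matches the 15 x 15, 7 x 7 and 3 x 3 shapes reported for `conv2d_1` through `conv2d_3`.  The same calculation applies to the height, since the images and kernels are square."
   ]
  },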
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's worth checking your intuition and understanding of what's going on by looking at the summary output and verifying that the numbers make sense.  For instance, why does the first convolutional layer have 560 parameters?  Where does that come from?  Well, we have a kernel size of 3, which creates a 3 x 3 filter (i.e. 9 parameters), but we also have 3 color channels, so each filter is really 3 x 3 x 3 = 27 parameters, plus 1 for the bias, for 28 per filter.  We specified 20 filters in the first layer, so 28 x 20 = 560.  Try applying similar logic to the second conv layer and see if the result makes sense.\n",
    "\n",
    "Now that we've got a model, let's try training it using the exact same approach as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "196/196 [==============================] - 25s 127ms/step - loss: 1.8725 - acc: 0.3019 - val_loss: 1.7737 - val_acc: 0.3772\n",
      "Epoch 2/10\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 1.6342 - acc: 0.4015 - val_loss: 1.5930 - val_acc: 0.4314\n",
      "Epoch 3/10\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 1.5503 - acc: 0.4349 - val_loss: 1.5013 - val_acc: 0.4567\n",
      "Epoch 4/10\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 1.4848 - acc: 0.4623 - val_loss: 1.4356 - val_acc: 0.4801\n",
      "Epoch 5/10\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 1.4493 - acc: 0.4798 - val_loss: 1.3845 - val_acc: 0.4972\n",
      "Epoch 6/10\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 1.4186 - acc: 0.4892 - val_loss: 1.3761 - val_acc: 0.5066\n",
      "Epoch 7/10\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.3999 - acc: 0.4956 - val_loss: 1.3681 - val_acc: 0.5024\n",
      "Epoch 8/10\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.3837 - acc: 0.5047 - val_loss: 1.4632 - val_acc: 0.4810\n",
      "Epoch 9/10\n",
      "196/196 [==============================] - 23s 120ms/step - loss: 1.3838 - acc: 0.5006 - val_loss: 1.3647 - val_acc: 0.5139\n",
      "Epoch 10/10\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 1.3565 - acc: 0.5114 - val_loss: 1.3553 - val_acc: 0.5162\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f5cb36ebcc0>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),\n",
    "                    epochs=10, validation_data=(x_test, y_test), workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results are a lot different this time!  The model is clearly learning, and after 10 epochs we're at about 50% accuracy on the validation set.  Still, we should be able to do a lot better.  For the next attempt let's introduce a few new wrinkles.  First, we're going to add batch normalization after each strided conv layer.  Second, we're going to add a single conv layer at the beginning with a larger kernel size and a stride of 1, so the first layer doesn't shrink the image.  Third, we're going to introduce padding, which changes the output shape of each conv layer.  Finally, we're going to add a few more layers to make the model bigger.\n",
    "\n",
    "To make the model definition more modular, I've pulled out the conv layer into a separate class.  There are multiple ways to do this (a function would have worked just as well) but I opted to mimic the way Keras's functional API works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers import BatchNormalization\n",
    "\n",
    "class ConvLayer:\n",
    "    def __init__(self, filters, kernel_size, stride):\n",
    "        self.filters = filters\n",
    "        self.kernel_size = kernel_size\n",
    "        self.stride = stride\n",
    "\n",
    "    def __call__(self, x):\n",
    "        x = Conv2D(self.filters, kernel_size=self.kernel_size,\n",
    "                   strides=self.stride, padding='same', use_bias=False)(x)\n",
    "        x = Activation('relu')(x)\n",
    "        x = BatchNormalization()(x)\n",
    "        return x\n",
    "\n",
    "def ConvNet2(in_shape, layers, n_classes, lr):\n",
    "    i = Input(shape=in_shape)\n",
    "    \n",
    "    x = Conv2D(layers[0], kernel_size=5, strides=1, padding='same')(i)\n",
    "    x = Activation('relu')(x)\n",
    "    \n",
    "    for n in range(1, len(layers)):\n",
    "        x = ConvLayer(layers[n], kernel_size=3, stride=2)(x)\n",
    "\n",
    "    x = GlobalMaxPooling2D()(x)\n",
    "    x = Dense(n_classes)(x)\n",
    "    x = Activation('softmax')(x)\n",
    "    \n",
    "    model = Model(inputs=i, outputs=x)\n",
    "    opt = Adam(lr=lr)\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_3 (InputLayer)         (None, 32, 32, 3)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 32, 32, 10)        760       \n",
      "_________________________________________________________________\n",
      "activation_7 (Activation)    (None, 32, 32, 10)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 16, 16, 20)        1800      \n",
      "_________________________________________________________________\n",
      "activation_8 (Activation)    (None, 16, 16, 20)        0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch (None, 16, 16, 20)        80        \n",
      "_________________________________________________________________\n",
      "conv2d_6 (Conv2D)            (None, 8, 8, 40)          7200      \n",
      "_________________________________________________________________\n",
      "activation_9 (Activation)    (None, 8, 8, 40)          0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_2 (Batch (None, 8, 8, 40)          160       \n",
      "_________________________________________________________________\n",
      "conv2d_7 (Conv2D)            (None, 4, 4, 80)          28800     \n",
      "_________________________________________________________________\n",
      "activation_10 (Activation)   (None, 4, 4, 80)          0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_3 (Batch (None, 4, 4, 80)          320       \n",
      "_________________________________________________________________\n",
      "conv2d_8 (Conv2D)            (None, 2, 2, 160)         115200    \n",
      "_________________________________________________________________\n",
      "activation_11 (Activation)   (None, 2, 2, 160)         0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_4 (Batch (None, 2, 2, 160)         640       \n",
      "_________________________________________________________________\n",
      "global_max_pooling2d_2 (Glob (None, 160)               0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 10)                1610      \n",
      "_________________________________________________________________\n",
      "activation_12 (Activation)   (None, 10)                0         \n",
      "=================================================================\n",
      "Total params: 156,570\n",
      "Trainable params: 155,970\n",
      "Non-trainable params: 600\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = ConvNet2(in_shape, [10, 20, 40, 80, 160], n_classes, lr)\n",
    "model.summary()"
   ]
  },
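  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The numbers in this summary can be checked the same way as before.  The following is a sketch of the arithmetic, assuming Keras's `padding='same'` behavior (the output width is the input width divided by the stride, rounded up) and the default `BatchNormalization` settings.  With $n$ the input width and $s$ the stride, the one stride-1 layer followed by four stride-2 layers gives\n",
    "\n",
    "$$n_{out} = \\left\\lceil \\frac{n}{s} \\right\\rceil: \\qquad 32 \\rightarrow 32 \\rightarrow 16 \\rightarrow 8 \\rightarrow 4 \\rightarrow 2$$\n",
    "\n",
    "The stride-2 conv layers set `use_bias=False`, so the bias term drops out of the parameter count.  With kernel size $k$, $c_{in}$ input channels and $f$ filters:\n",
    "\n",
    "$$k \\cdot k \\cdot c_{in} \\cdot f: \\qquad 3 \\cdot 3 \\cdot 10 \\cdot 20 = 1800, \\qquad 3 \\cdot 3 \\cdot 20 \\cdot 40 = 7200$$\n",
    "\n",
    "Each batch normalization layer stores a learned scale and shift plus a running mean and variance for every channel, so a layer with $f$ channels has $4f$ parameters, of which the $2f$ running statistics are not trainable:\n",
    "\n",
    "$$4 \\cdot 20 = 80, \\qquad 2 \\cdot (20 + 40 + 80 + 160) = 600$$\n",
    "\n",
    "The last figure matches the non-trainable parameter count at the bottom of the summary."
   ]
  },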
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We made a bunch of improvements and the network has a much larger capacity, so let's see what it does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "196/196 [==============================] - 24s 125ms/step - loss: 1.6451 - acc: 0.4258 - val_loss: 1.5408 - val_acc: 0.4597\n",
      "Epoch 2/10\n",
      "196/196 [==============================] - 23s 118ms/step - loss: 1.3130 - acc: 0.5280 - val_loss: 1.7158 - val_acc: 0.4559\n",
      "Epoch 3/10\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.1669 - acc: 0.5803 - val_loss: 1.5101 - val_acc: 0.5311\n",
      "Epoch 4/10\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 1.0642 - acc: 0.6205 - val_loss: 1.3304 - val_acc: 0.5538\n",
      "Epoch 5/10\n",
      "196/196 [==============================] - 23s 118ms/step - loss: 0.9887 - acc: 0.6485 - val_loss: 1.2749 - val_acc: 0.5955\n",
      "Epoch 6/10\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 0.9264 - acc: 0.6717 - val_loss: 1.3210 - val_acc: 0.5819\n",
      "Epoch 7/10\n",
      "196/196 [==============================] - 23s 120ms/step - loss: 0.8812 - acc: 0.6887 - val_loss: 0.9221 - val_acc: 0.6807\n",
      "Epoch 8/10\n",
      "196/196 [==============================] - 23s 120ms/step - loss: 0.8437 - acc: 0.6985 - val_loss: 0.8809 - val_acc: 0.7012\n",
      "Epoch 9/10\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 0.8196 - acc: 0.7083 - val_loss: 0.9064 - val_acc: 0.6873\n",
      "Epoch 10/10\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 0.7897 - acc: 0.7194 - val_loss: 0.8259 - val_acc: 0.7179\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f5cb2d75ac8>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),\n",
    "                    epochs=10, validation_data=(x_test, y_test), workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's a significant improvement!  Our validation accuracy after 10 epochs jumped all the way from ~50% to ~70%.  We're already doing pretty well, but there's one more major addition we can make that should bump performance even higher.  A key addition to modern convolutional networks was the invention of [residual layers](https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035), which add an \"identity\" connection from the input of a block of convolutions to its output.  Below I've added a new \"ResLayer\" class that inherits from \"ConvLayer\" but outputs the sum of the original input and the output from the conv layer.  Building on the previous network, we've now added two residual layers to each \"block\" in the model definition.  These residual layers have a stride of 1 so they don't change the shape of the output.  Finally, we've added a bit of regularization (a small L2 penalty on the conv kernels and dropout before the final dense layer) to keep the model from overfitting too badly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import layers\n",
    "from keras import regularizers\n",
    "from keras.layers import Dropout\n",
    "\n",
    "class ConvLayer:\n",
    "    def __init__(self, filters, kernel_size, stride):\n",
    "        self.filters = filters\n",
    "        self.kernel_size = kernel_size\n",
    "        self.stride = stride\n",
    "\n",
    "    def __call__(self, x):\n",
    "        x = Conv2D(self.filters, kernel_size=self.kernel_size,\n",
    "                   strides=self.stride, padding='same', use_bias=False,\n",
    "                   kernel_regularizer=regularizers.l2(1e-6))(x)\n",
    "        x = Activation('relu')(x)\n",
    "        x = BatchNormalization()(x)\n",
    "        return x\n",
    "\n",
    "class ResLayer(ConvLayer):\n",
    "    def __call__(self, x):\n",
    "        return layers.add([x, super().__call__(x)])\n",
    "\n",
    "def ResNet(in_shape, layers, n_classes, lr):\n",
    "    i = Input(shape=in_shape)\n",
    "    \n",
    "    x = Conv2D(layers[0], kernel_size=7, strides=1, padding='same')(i)\n",
    "    x = Activation('relu')(x)\n",
    "\n",
    "    for n in range(1, len(layers)):\n",
    "        x = ConvLayer(layers[n], kernel_size=3, stride=2)(x)\n",
    "        x = ResLayer(layers[n], kernel_size=3, stride=1)(x)\n",
    "        x = ResLayer(layers[n], kernel_size=3, stride=1)(x)\n",
    "\n",
    "    x = GlobalMaxPooling2D()(x)\n",
    "    x = Dropout(0.1)(x)\n",
    "    x = Dense(n_classes)(x)\n",
    "    x = Activation('softmax')(x)\n",
    "    \n",
    "    model = Model(inputs=i, outputs=x)\n",
    "    opt = Adam(lr=lr)\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])\n",
    "    \n",
    "    return model"
   ]
  },
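  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written as an equation, each `ResLayer` above computes the following.  This is only a sketch of what the class does, where $F$ is shorthand for the conv, ReLU and batch norm steps inherited from `ConvLayer`:\n",
    "\n",
    "$$y = x + F(x), \\qquad F(x) = \\mathrm{BN}(\\mathrm{ReLU}(\\mathrm{Conv}(x)))$$\n",
    "\n",
    "The addition only works if $x$ and $F(x)$ have the same shape, which is why the residual layers use a stride of 1, `padding='same'`, and the same number of filters as the layer before them.  The stride-2 `ConvLayer` at the start of each block is the one that changes the shape, so it has no identity connection."
   ]
  },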
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_4 (InputLayer)            (None, 32, 32, 3)    0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_9 (Conv2D)               (None, 32, 32, 10)   1480        input_4[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "activation_13 (Activation)      (None, 32, 32, 10)   0           conv2d_9[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_10 (Conv2D)              (None, 16, 16, 20)   1800        activation_13[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_14 (Activation)      (None, 16, 16, 20)   0           conv2d_10[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_5 (BatchNor (None, 16, 16, 20)   80          activation_14[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_11 (Conv2D)              (None, 16, 16, 20)   3600        batch_normalization_5[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "activation_15 (Activation)      (None, 16, 16, 20)   0           conv2d_11[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_6 (BatchNor (None, 16, 16, 20)   80          activation_15[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_1 (Add)                     (None, 16, 16, 20)   0           batch_normalization_5[0][0]      \n",
      "                                                                 batch_normalization_6[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_12 (Conv2D)              (None, 16, 16, 20)   3600        add_1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_16 (Activation)      (None, 16, 16, 20)   0           conv2d_12[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_7 (BatchNor (None, 16, 16, 20)   80          activation_16[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_2 (Add)                     (None, 16, 16, 20)   0           add_1[0][0]                      \n",
      "                                                                 batch_normalization_7[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_13 (Conv2D)              (None, 8, 8, 40)     7200        add_2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_17 (Activation)      (None, 8, 8, 40)     0           conv2d_13[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_8 (BatchNor (None, 8, 8, 40)     160         activation_17[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_14 (Conv2D)              (None, 8, 8, 40)     14400       batch_normalization_8[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "activation_18 (Activation)      (None, 8, 8, 40)     0           conv2d_14[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_9 (BatchNor (None, 8, 8, 40)     160         activation_18[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_3 (Add)                     (None, 8, 8, 40)     0           batch_normalization_8[0][0]      \n",
      "                                                                 batch_normalization_9[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_15 (Conv2D)              (None, 8, 8, 40)     14400       add_3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_19 (Activation)      (None, 8, 8, 40)     0           conv2d_15[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_10 (BatchNo (None, 8, 8, 40)     160         activation_19[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_4 (Add)                     (None, 8, 8, 40)     0           add_3[0][0]                      \n",
      "                                                                 batch_normalization_10[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_16 (Conv2D)              (None, 4, 4, 80)     28800       add_4[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_20 (Activation)      (None, 4, 4, 80)     0           conv2d_16[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_11 (BatchNo (None, 4, 4, 80)     320         activation_20[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_17 (Conv2D)              (None, 4, 4, 80)     57600       batch_normalization_11[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "activation_21 (Activation)      (None, 4, 4, 80)     0           conv2d_17[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_12 (BatchNo (None, 4, 4, 80)     320         activation_21[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_5 (Add)                     (None, 4, 4, 80)     0           batch_normalization_11[0][0]     \n",
      "                                                                 batch_normalization_12[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_18 (Conv2D)              (None, 4, 4, 80)     57600       add_5[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_22 (Activation)      (None, 4, 4, 80)     0           conv2d_18[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_13 (BatchNo (None, 4, 4, 80)     320         activation_22[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_6 (Add)                     (None, 4, 4, 80)     0           add_5[0][0]                      \n",
      "                                                                 batch_normalization_13[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_19 (Conv2D)              (None, 2, 2, 160)    115200      add_6[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_23 (Activation)      (None, 2, 2, 160)    0           conv2d_19[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_14 (BatchNo (None, 2, 2, 160)    640         activation_23[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_20 (Conv2D)              (None, 2, 2, 160)    230400      batch_normalization_14[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "activation_24 (Activation)      (None, 2, 2, 160)    0           conv2d_20[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_15 (BatchNo (None, 2, 2, 160)    640         activation_24[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_7 (Add)                     (None, 2, 2, 160)    0           batch_normalization_14[0][0]     \n",
      "                                                                 batch_normalization_15[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_21 (Conv2D)              (None, 2, 2, 160)    230400      add_7[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_25 (Activation)      (None, 2, 2, 160)    0           conv2d_21[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_16 (BatchNo (None, 2, 2, 160)    640         activation_25[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_8 (Add)                     (None, 2, 2, 160)    0           add_7[0][0]                      \n",
      "                                                                 batch_normalization_16[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "global_max_pooling2d_3 (GlobalM (None, 160)          0           add_8[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 160)          0           global_max_pooling2d_3[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "dense_5 (Dense)                 (None, 10)           1610        dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "activation_26 (Activation)      (None, 10)           0           dense_5[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 771,690\n",
      "Trainable params: 769,890\n",
      "Non-trainable params: 1,800\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = ResNet(in_shape, [10, 20, 40, 80, 160], n_classes, lr)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model summary is now getting quite large, but you can still follow through each layer and make sense of what's happening. Let's run this one last time and see what the results look like. We'll increase the epoch count since deeper networks tend to take longer to train."
   ]
  },
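  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As one example of following a layer through the summary, the parameter count of the final dense layer can be checked by hand.  Global max pooling reduces the `(2, 2, 160)` tensor to a vector of 160 values, and the dense layer maps those to the 10 CIFAR 10 classes with one bias per class:\n",
    "\n",
    "$$160 \\times 10 + 10 = 1610$$\n",
    "\n",
    "This matches the 1,610 parameters reported for `dense_5`.  The pooling, dropout, and activation layers report 0 parameters because they have no weights to learn."
   ]
  },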
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/40\n",
      "196/196 [==============================] - 28s 145ms/step - loss: 1.9806 - acc: 0.3498 - val_loss: 7.4266 - val_acc: 0.0771\n",
      "Epoch 2/40\n",
      "196/196 [==============================] - 23s 118ms/step - loss: 1.5761 - acc: 0.4484 - val_loss: 2.0037 - val_acc: 0.3478\n",
      "Epoch 3/40\n",
      "196/196 [==============================] - 24s 124ms/step - loss: 1.5488 - acc: 0.4612 - val_loss: 14.3443 - val_acc: 0.1005\n",
      "Epoch 4/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 1.6194 - acc: 0.4359 - val_loss: 2.5182 - val_acc: 0.2401\n",
      "Epoch 5/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.5562 - acc: 0.4626 - val_loss: 2.0495 - val_acc: 0.3302\n",
      "Epoch 6/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.6183 - acc: 0.4400 - val_loss: 2.9989 - val_acc: 0.1782\n",
      "Epoch 7/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.4886 - acc: 0.4672 - val_loss: 1.3995 - val_acc: 0.4944\n",
      "Epoch 8/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.3551 - acc: 0.5162 - val_loss: 1.3086 - val_acc: 0.5268\n",
      "Epoch 9/40\n",
      "196/196 [==============================] - 24s 123ms/step - loss: 1.2971 - acc: 0.5373 - val_loss: 1.2979 - val_acc: 0.5423\n",
      "Epoch 10/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.2737 - acc: 0.5507 - val_loss: 8.2801 - val_acc: 0.1325\n",
      "Epoch 11/40\n",
      "196/196 [==============================] - 24s 123ms/step - loss: 1.3697 - acc: 0.5350 - val_loss: 1.2361 - val_acc: 0.5742\n",
      "Epoch 12/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.2410 - acc: 0.5652 - val_loss: 1.1365 - val_acc: 0.6007\n",
      "Epoch 13/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.1514 - acc: 0.5958 - val_loss: 1.1343 - val_acc: 0.6118\n",
      "Epoch 14/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 1.1079 - acc: 0.6096 - val_loss: 1.1276 - val_acc: 0.6092\n",
      "Epoch 15/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 1.0586 - acc: 0.6306 - val_loss: 1.0696 - val_acc: 0.6330\n",
      "Epoch 16/40\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 1.0240 - acc: 0.6437 - val_loss: 1.0270 - val_acc: 0.6596\n",
      "Epoch 17/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 0.9809 - acc: 0.6611 - val_loss: 1.0828 - val_acc: 0.6391\n",
      "Epoch 18/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.9591 - acc: 0.6685 - val_loss: 0.9332 - val_acc: 0.6848\n",
      "Epoch 19/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 0.9166 - acc: 0.6860 - val_loss: 0.9894 - val_acc: 0.6632\n",
      "Epoch 20/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.8854 - acc: 0.6983 - val_loss: 1.1848 - val_acc: 0.6169\n",
      "Epoch 21/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 0.8659 - acc: 0.7045 - val_loss: 0.9105 - val_acc: 0.6978\n",
      "Epoch 22/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 0.8366 - acc: 0.7162 - val_loss: 0.8779 - val_acc: 0.7132\n",
      "Epoch 23/40\n",
      "196/196 [==============================] - 23s 120ms/step - loss: 0.8175 - acc: 0.7252 - val_loss: 1.8874 - val_acc: 0.5708\n",
      "Epoch 24/40\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 0.8383 - acc: 0.7203 - val_loss: 0.9611 - val_acc: 0.6878\n",
      "Epoch 25/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.7910 - acc: 0.7360 - val_loss: 0.8956 - val_acc: 0.7037\n",
      "Epoch 26/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.7728 - acc: 0.7445 - val_loss: 0.8712 - val_acc: 0.7297\n",
      "Epoch 27/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.7532 - acc: 0.7514 - val_loss: 0.8697 - val_acc: 0.7191\n",
      "Epoch 28/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.7419 - acc: 0.7568 - val_loss: 0.7995 - val_acc: 0.7405\n",
      "Epoch 29/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 0.7385 - acc: 0.7599 - val_loss: 0.8080 - val_acc: 0.7451\n",
      "Epoch 30/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.7202 - acc: 0.7663 - val_loss: 0.9121 - val_acc: 0.7253\n",
      "Epoch 31/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.7078 - acc: 0.7737 - val_loss: 0.8999 - val_acc: 0.7223\n",
      "Epoch 32/40\n",
      "196/196 [==============================] - 24s 120ms/step - loss: 0.6969 - acc: 0.7756 - val_loss: 0.9682 - val_acc: 0.7135\n",
      "Epoch 33/40\n",
      "196/196 [==============================] - 24s 121ms/step - loss: 0.6851 - acc: 0.7825 - val_loss: 0.8145 - val_acc: 0.7456\n",
      "Epoch 34/40\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 0.6800 - acc: 0.7859 - val_loss: 0.7972 - val_acc: 0.7585\n",
      "Epoch 35/40\n",
      "196/196 [==============================] - 23s 118ms/step - loss: 0.6689 - acc: 0.7919 - val_loss: 0.7807 - val_acc: 0.7654\n",
      "Epoch 36/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 0.6626 - acc: 0.7949 - val_loss: 0.8022 - val_acc: 0.7509\n",
      "Epoch 37/40\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 0.6550 - acc: 0.7987 - val_loss: 0.8129 - val_acc: 0.7613\n",
      "Epoch 38/40\n",
      "196/196 [==============================] - 24s 122ms/step - loss: 0.6532 - acc: 0.8006 - val_loss: 0.8861 - val_acc: 0.7359\n",
      "Epoch 39/40\n",
      "196/196 [==============================] - 23s 119ms/step - loss: 0.6419 - acc: 0.8043 - val_loss: 0.8233 - val_acc: 0.7568\n",
      "Epoch 40/40\n",
      "196/196 [==============================] - 24s 124ms/step - loss: 0.6308 - acc: 0.8109 - val_loss: 0.7809 - val_acc: 0.7670\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f5cad8b5c50>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),\n",
    "                    epochs=40, validation_data=(x_test, y_test), workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results look pretty good: validation accuracy ends up around 77% after 40 epochs.  We're starting to hit the point where accuracy improvements are getting harder to come by.  It's definitely possible to keep improving the model with the right tuning and augmentation strategies, but diminishing returns start to kick in relative to the effort involved.  Also, as the network keeps getting bigger (and as we graduate to larger and more complex data sets), it becomes much, much harder to build a network from scratch.\n",
    "\n",
    "Fortunately there's an alternative solution via [transfer learning](https://machinelearningmastery.com/transfer-learning-for-deep-learning/), which takes a model trained on one task and adapts it to another task.  Combined with pre-training, which is the practice of using a model that's already been trained for a given task, we can take very large networks developed by the likes of Google and Facebook and then fine-tune them to work in a custom domain of our choosing.  Below I'll walk through an example of how this works by using a pre-trained ImageNet model and adapting it to Kaggle's [dogs vs cats](https://www.kaggle.com/c/dogs-vs-cats) data set.\n",
    "\n",
    "First get some imports out of the way.  We'll need all of this stuff throughout the exercise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from keras.applications import ResNet50\n",
    "from keras.applications.resnet50 import preprocess_input\n",
    "from keras.layers import Dense, GlobalAveragePooling2D\n",
    "from keras.models import Model\n",
    "from keras.optimizers import RMSprop\n",
    "from keras.preprocessing import image\n",
    "from keras.preprocessing.image import ImageDataGenerator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The easiest way to get the data set is via fast.ai's servers, where they've graciously hosted a [single zip file](http://files.fast.ai/data/dogscats.zip) with everything we need.  Extract this to a directory somewhere on your machine and update the \"PATH\" variable below, and you should be good to go.  We can also specify a few useful constants such as the image dimension and batch size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH = '/home/paperspace/data/dogscats/'\n",
    "train_dir = f'{PATH}train'\n",
    "valid_dir = f'{PATH}valid'\n",
    "size = 224\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we need a generator to apply transformations to the images. As before, we can use the generator Keras has built-in. The only wrinkle is using a specialized preprocessing function designed for ImageNet-like source data (this also comes with Keras and was imported above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_datagen = ImageDataGenerator(\n",
    "    shear_range=0.2,\n",
    "    zoom_range=0.2,\n",
    "    preprocessing_function=preprocess_input,\n",
    "    horizontal_flip=True)\n",
    "\n",
    "val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With CIFAR 10 we had the whole data set loaded into memory, but that strategy usually isn't feasible for larger image databases. In this case we have a bunch of image files in folders on disk as our starting point, and to run a model over these images we want to be able to stream images into memory in batches rather than load everything at once. Fortunately Keras can also handle this scenario natively using the \"flow_from_directory\" function. We just need to specify the directory, image size, and batch size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 23000 images belonging to 2 classes.\n",
      "Found 2000 images belonging to 2 classes.\n"
     ]
    }
   ],
   "source": [
    "train_generator = train_datagen.flow_from_directory(train_dir,\n",
    "    target_size=(size, size),\n",
    "    batch_size=batch_size, class_mode='binary')\n",
    "\n",
    "val_generator = val_datagen.flow_from_directory(valid_dir,\n",
    "    shuffle=False,\n",
    "    target_size=(size, size),\n",
    "    batch_size=batch_size, class_mode='binary')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the model, we'll use the ResNet-50 architecture with pre-trained weights.  ResNet-50 is a residual network with 50 weight layers (the Keras implementation contains many more layer objects than that, since batch norm, activation, and add operations each count as a layer) and reaches roughly 92% top-5 accuracy on ImageNet classification.  Keras provides both the model architecture and an option to use existing weights out of the box.  The other notable parameter in the model initializer is \"include_top\", which indicates if we want to include the fully-connected layer at the top of the network.  In our case the answer is no, because we want to \"hook into\" the model after the last residual block and add our own architecture on top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = ResNet50(weights='imagenet', include_top=False)\n",
    "x = base_model.output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After instantiating the pre-trained ResNet-50 model, we can start adding new layers to the architecture.  Let's start with a global average pooling layer, which collapses each feature map to a single value so the output is a fixed-length vector, then add a fully-connected layer of our own.  Finally, we'll use a sigmoid unit for class probability since the task is binary (cat or dog)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = GlobalAveragePooling2D()(x)\n",
    "x = Dense(1024, activation='relu')(x)\n",
    "preds = Dense(1, activation='sigmoid')(x)"
   ]
  },
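  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a rough, back-of-the-envelope count of what this new head adds, assuming the base model's final feature map has 2,048 channels (the standard ResNet-50 width, not something printed in this notebook).  Global average pooling turns that into a 2,048-element vector, so the two dense layers contribute weights plus biases as follows:\n",
    "\n",
    "$$\\text{Dense(1024)}: \\; 2048 \\times 1024 + 1024 = 2{,}098{,}176$$\n",
    "\n",
    "$$\\text{Dense(1)}: \\; 1024 \\times 1 + 1 = 1{,}025$$\n",
    "\n",
    "Under that assumption the head has about 2.1 million parameters (2,099,201 in total), and these are the only weights that get updated while the base layers are frozen."
   ]
  },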
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before finishing the model definition and compiling, there's one more notable step.  We need to prevent the \"base\" layers of the model from participating in the weight update phase of training while we \"break in\" the new layers we just added.  Since each layer in a Keras model has a \"trainable\" property, we can just set it to false for all layers in the base architecture.\n",
    "\n",
    "(Aside: There is apparently some funkiness to using this approach in models that have batch norm layers that can lead to sub-optimal results, especially when doing fine-tuning, which we'll get to in a few steps.  I haven't seen a conclusive answer on how to deal with this, and the naive approach seems to work okay for this problem, so I'm not doing anything special to deal with it here, but I wanted to point it out as a potential issue one might run into.  There's a lengthy discussion on the subject [here](https://github.com/keras-team/keras/pull/9965))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(inputs=base_model.input, outputs=preds)\n",
    "for layer in base_model.layers: layer.trainable = False\n",
    "model.compile(optimizer=RMSprop(lr=0.001), loss='binary_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training should look pretty familiar.  The only wrinkle here is that we need to specify the number of batches (steps) in an epoch when using the \"flow_from_directory\" generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "359/359 [==============================] - 128s 357ms/step - loss: 0.1738 - acc: 0.9506 - val_loss: 0.0694 - val_acc: 0.9839\n",
      "Epoch 2/3\n",
      "359/359 [==============================] - 123s 342ms/step - loss: 0.0809 - acc: 0.9729 - val_loss: 0.1059 - val_acc: 0.9778\n",
      "Epoch 3/3\n",
      "359/359 [==============================] - 123s 344ms/step - loss: 0.0717 - acc: 0.9755 - val_loss: 0.1411 - val_acc: 0.9723\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, train_generator.n // batch_size, epochs=3, workers=4,\n",
    "                              validation_data=val_generator, validation_steps=val_generator.n // batch_size)"
   ]
  },
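  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the step counts passed to `fit_generator` above: the generators reported 23,000 training images and 2,000 validation images, and the batch size is 64, so integer division gives\n",
    "\n",
    "$$\\left\\lfloor \\frac{23000}{64} \\right\\rfloor = 359 \\qquad \\left\\lfloor \\frac{2000}{64} \\right\\rfloor = 31$$\n",
    "\n",
    "The 359 matches the step count shown in the progress bar.  Because the division rounds down, the final partial batch is not counted: 359 steps cover $359 \\times 64 = 22{,}976$ training images, and 31 steps cover $31 \\times 64 = 1{,}984$ validation images, leaving 16 of the 2,000 out of a validation pass.  Rounding up instead would be one way to cover every image, at the cost of one smaller batch."
   ]
  },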
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These results aren't too bad even with the entire base architecture held constant: validation accuracy is already around 97-98%.  This is partly due to the fact that the training images are quite similar to the images that the architecture was trained on.  If we were fitting the model on something totally different, say medical image classification, transfer learning would still work but it wouldn't be this easy.\n",
    "\n",
    "The next step is to fine-tune some of the base model by \"unfreezing\" parts of it and allowing them to update weights during training.  I'm not aware of any firm best practices for fine-tuning; I think it's generally a lot of trial and error.  For this attempt, I unfroze the last residual block in the network (everything from layer index 142 onward) and lowered the learning rate by an order of magnitude."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "for layer in model.layers[:142]: layer.trainable = False\n",
    "for layer in model.layers[142:]: layer.trainable = True\n",
    "model.compile(optimizer=RMSprop(lr=0.0001), loss='binary_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "359/359 [==============================] - 151s 421ms/step - loss: 0.0468 - acc: 0.9826 - val_loss: 1.0175 - val_acc: 0.9098\n",
      "Epoch 2/3\n",
      "359/359 [==============================] - 146s 406ms/step - loss: 0.0293 - acc: 0.9903 - val_loss: 0.1305 - val_acc: 0.9829\n",
      "Epoch 3/3\n",
      "359/359 [==============================] - 146s 406ms/step - loss: 0.0211 - acc: 0.9938 - val_loss: 0.1197 - val_acc: 0.9849\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, train_generator.n // batch_size, epochs=3, workers=4,\n",
    "                              validation_data=val_generator, validation_steps=val_generator.n // batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
